{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "948708a6-e94e-449a-90fc-43d327d12674",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from os.path import join\n",
    "\n",
    "import dask.dataframe as dd\n",
    "import numpy as np\n",
    "import xgboost as xgb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98089ad9-efdd-459f-8545-85ddcc4dcb46",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Root directory of the dataset. Defaults to the original cluster mount; set the\n",
    "# MERLIN_DATA_PATH environment variable to point the notebook at another copy.\n",
    "DATA_PATH = os.environ.get(\n",
    "    'MERLIN_DATA_PATH',\n",
    "    '/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_sf-log1p',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abdf1337-e80e-4e15-83b3-056a2a6a01f2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Features are read from pre-computed .npy files under pca/ (presumably 256-component\n",
    "# PCA embeddings, going by the file names -- confirm). Labels are the cell_type column\n",
    "# of the matching parquet split, materialised from dask into a NumPy array.\n",
    "# class_weights.npy is not loaded here: the sample-weight cell loads it where it is used.\n",
    "x_train = np.load(join(DATA_PATH, 'pca/x_pca_training_train_split_256.npy'))\n",
    "y_train = dd.read_parquet(join(DATA_PATH, 'train'), columns='cell_type').compute().to_numpy()\n",
    "\n",
    "x_val = np.load(join(DATA_PATH, 'pca/x_pca_training_val_split_256.npy'))\n",
    "y_val = dd.read_parquet(join(DATA_PATH, 'val'), columns='cell_type').compute().to_numpy()\n",
    "\n",
    "# Features and labels come from separate files, so fail fast if they do not line up.\n",
    "assert len(x_train) == len(y_train), 'train features and labels differ in length'\n",
    "assert len(x_val) == len(y_val), 'val features and labels differ in length'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac43212b-edf1-4cee-9116-c2d887b5a6f1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Per-class weights: position i in class_weights.npy is the weight for class label i.\n",
    "class_weight_values = np.load(join(DATA_PATH, 'class_weights.npy'))\n",
    "class_weights = dict(enumerate(class_weight_values))\n",
    "\n",
    "# One weight per training row. Indexing the weight array with the label array replaces\n",
    "# a per-row Python dict lookup, which is slow when y_train is large.\n",
    "label_idx = np.asarray(y_train).astype(np.intp, copy=False)\n",
    "# Keep the dict-lookup failure mode: a label without a weight raises KeyError rather\n",
    "# than silently wrapping around (negative index) or raising IndexError.\n",
    "if label_idx.size and (label_idx.min() < 0 or label_idx.max() >= len(class_weight_values)):\n",
    "    raise KeyError('y_train contains a label with no entry in class_weights.npy')\n",
    "weights = class_weight_values[label_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72a702fa-3256-409c-b522-4e7869a020ab",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# XGBoost classifier trained on the PCA features with per-row sample weights.\n",
    "# The validation split passed as eval_set is what early_stopping_rounds monitors.\n",
    "clf = xgb.XGBClassifier(\n",
    "    # boosting schedule (eta is XGBoost's alias for the learning rate)\n",
    "    n_estimators=1000,\n",
    "    eta=0.075,\n",
    "    early_stopping_rounds=10,\n",
    "    # tree size and row subsampling\n",
    "    max_depth=10,\n",
    "    subsample=0.75,\n",
    "    # execution: GPU histogram tree method on device 0\n",
    "    tree_method='gpu_hist',\n",
    "    gpu_id=0,\n",
    "    n_jobs=20,\n",
    ")\n",
    "# fit() returns the estimator itself, so clf stays bound to the fitted classifier.\n",
    "clf = clf.fit(\n",
    "    X=x_train,\n",
    "    y=y_train,\n",
    "    sample_weight=weights,\n",
    "    eval_set=[(x_val, y_val)],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8723fffb-786e-48e0-a021-6a5d1e4fdf83",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Save the classifier to model.json. The path is relative, so the file is written\n",
    "# to the kernel's current working directory.\n",
    "clf.save_model(fname='model.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f01dc191-6f72-4d8f-9a7b-6f5bc46f2ad7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
